import math

from scipy import stats


class EarlyStopping(object):
    """Patience-based early stopping on a sequence of records (lower is better).

    A record counts as an improvement only when it is strictly smaller than the
    best record seen so far. Every non-improving record consumes one unit of
    patience. Once ``patience`` units have been consumed, the next non-improving
    record triggers an early stop. Any improvement resets the counter.

    Args:
        patience (int): How many consecutive non-improving records are tolerated
            before the following one triggers an early stop.

    Attributes:
        __record_list (list): Every record passed to :meth:`stop`, in order.
        __best (float): The smallest record seen so far (``None`` before the first call).
        __patience (int): How many non-improving records are tolerated.
        __p (int): How many consecutive non-improving records have been counted so far.

    """
    def __init__(self, patience):
        self.__patience = patience
        self.__best = None
        self.__p = 0
        self.__record_list = []

    def stop(self, new_value):
        """Record ``new_value`` and decide whether training should stop.

        Args:
            new_value (float): The new record generated by the newest model.

        Returns:
            bool: ``True`` when the tolerated number of non-improving records has
            already been used up and ``new_value`` is again not an improvement,
            otherwise ``False``.
        """
        self.__record_list.append(new_value)

        improved = self.__best is None or new_value < self.__best
        if improved:
            # New best: remember it and restore the full patience budget.
            self.__best = new_value
            self.__p = 0
            return False

        # Patience already exhausted: stop (the counter is not advanced further).
        if self.__p >= self.__patience:
            return True

        self.__p += 1
        return False


class EarlyStoppingTTest(object):
    """Early stop by Welch's t-test.

    Welch's t-test (``scipy.stats.ttest_ind`` with ``equal_var=False``) is a two-sided
    test for the null hypothesis that two independent samples have identical average
    (expected) values, without assuming equal variances. Each call compares the newest
    ``length`` records with the ``length`` records immediately before them. Early stop is
    triggered when the test cannot reject the null hypothesis (the averages look
    identical) or when the newer interval has the higher average (records are treated
    as losses, lower is better).

    Args:
        length (int): The length of each checked interval. Must be at least 1.
        p_value_threshold (float): The p-value threshold to decide whether to do early stop.

    Raises:
        ValueError: If ``length`` is smaller than 1.

    Attributes:
        __record_list (list): List of records.
        __best (float): Unused; kept for backward compatibility.
        __test_length (int): The length of each checked interval.
        __p_value_threshold (float): The p-value threshold to decide whether to do early stop.
    """
    def __init__(self, length, p_value_threshold):
        # With length <= 0 the slices in ``stop`` are wrong (``[-0:]`` is the whole list
        # and ``[-0:-0]`` is empty), so the test would be meaningless.
        if length < 1:
            raise ValueError('length must be at least 1, got {}'.format(length))
        self.__record_list = []
        self.__best = None
        self.__test_length = length
        self.__p_value_threshold = p_value_threshold

    def stop(self, new_value):
        """Append the new record and t-test the two newest intervals against each other.

        No test is performed until at least ``2 * length`` records are available. The
        t statistic and p-value are printed to stdout every time the test runs.

        Args:
            new_value (float): The new record generated by the newest model.

        Returns:
            bool: ``True`` (trigger early stop) if the p-value is greater than the
            threshold, or the newer interval has the larger mean (positive t statistic).
            If the statistic is undefined (NaN), e.g. both intervals are constant and
            equal, the interval means are compared directly and ``True`` is returned
            when the newer mean is not lower. Otherwise ``False``.
        """
        self.__record_list.append(new_value)
        length = self.__test_length
        if len(self.__record_list) < length * 2:
            return False

        recent = self.__record_list[-length:]
        previous = self.__record_list[-length * 2:-length]
        loss_ttest = stats.ttest_ind(recent, previous, equal_var=False)
        ttest = loss_ttest[0]
        p_value = loss_ttest[1]
        print('ttest:', ttest, 'pValue', p_value)

        if math.isnan(ttest) or math.isnan(p_value):
            # Zero variance in both intervals (an exact plateau) or single-record
            # intervals make Welch's statistic 0/0 = NaN. Any comparison with NaN is
            # False, so without this branch an exact plateau would never stop.
            # Fall back to comparing the interval means.
            recent_mean = sum(recent) / len(recent)
            previous_mean = sum(previous) / len(previous)
            return bool(recent_mean >= previous_mean)

        return bool(p_value > self.__p_value_threshold or ttest > 0)
